import argparse
from model import CQDTransEA
from datasets import Dataset
from regularizers import N2, N3
from torch import optim
from optimizers import LP_optimizer
from datasets import *
from torch.utils.data import DataLoader
from loss import CELoss,MAELoss
import os
import time
import json
from torch.utils.tensorboard import SummaryWriter

def log_evaluation(writer, dataset, model, device, step):
    """Run attribute and numeric-value evaluation and log every metric to TensorBoard.

    Args:
        writer: a ``SummaryWriter`` (anything exposing ``add_scalar(tag, value, step)``).
        dataset: object exposing ``att_eval(device, model)`` (returning four metrics)
            and ``num_eval(device, model, split, limit)`` (returning an indexable
            whose first three entries are MAE, MSE and RMSE, per the tags below).
        model: the model to evaluate.
        device: device identifier forwarded unchanged to the evaluation helpers.
        step: TensorBoard global step under which all scalars are recorded.
    """
    # Attribute prediction: hits@10 / MRR on the train and test splits
    # (order as returned by att_eval).
    hits10_train, hits10_test, mrr_train, mrr_test = dataset.att_eval(device, model)
    writer.add_scalar('all_att_hits10_train', hits10_train, step)
    writer.add_scalar('all_att_hits10_test', hits10_test, step)
    writer.add_scalar('all_att_mrr_train', mrr_train, step)
    writer.add_scalar('all_att_mrr_test', mrr_test, step)

    # Numeric value prediction. -1 is passed for valid/test and 5000 for train
    # (presumably a sample cap keeping the train evaluation cheap — confirm in num_eval).
    results = {
        split: dataset.num_eval(device, model, split, 5000 if split == 'train' else -1)
        for split in ("valid", "test", "train")
    }
    for split in ("train", "valid", "test"):
        for idx, metric in enumerate(("mae", "mse", "rmse")):
            writer.add_scalar(f'{split}_{metric}', results[split][idx], step)


if __name__ == "__main__":
    # Imported explicitly: torch.save below used to rely on `from datasets import *`
    # happening to re-export torch.
    import torch

    parser = argparse.ArgumentParser(
        description="Relational learning contraption"
    )
    parser.add_argument('--rank', type=int, default=300, help="embedding size")
    parser.add_argument("--regularizer", type=str, default="N3", choices=["N2", "N3"])
    parser.add_argument('--learning_rate', type=float, default=1e-3, help="Learning rate")
    parser.add_argument("--max_epochs", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=1024)
    parser.add_argument("--data_path", type=str, default="/home/pxs/projects/paper_nrn_new/model_proposed_accelerate_numpre_task2/line_predictor_model/preprocessed_data_for_linepredictor.pkl")
    parser.add_argument("--reg", type=float, default=0.01)
    parser.add_argument("--alpha", type=float, default=0.3, help="the ratio of two loss")
    parser.add_argument("--use_cuda", action="store_true")
    parser.add_argument("--cuda_device", type=int, default=4,
                        help="CUDA device index used when --use_cuda is set")
    parser.add_argument("--valid_and_test", type=int, default=3)
    parser.add_argument("--dataset_name", type=str, default="FB15K")
    parser.add_argument("--log_dir", type=str,
                        default='/home/pxs/projects/paper_nrn_new/model_proposed_accelerate_numpre_task2/line_predictor_model/value_prediction_test_outcome',
                        help="TensorBoard log directory")
    args = parser.parse_args()
    # Used as a modulus in the epoch loop; 0 or negatives would crash or misbehave.
    if args.valid_and_test < 1:
        parser.error("--valid_and_test must be a positive integer")

    dataset = Dataset(args.data_path)

    # Shared by the constructor and the saved checkpoint so the two cannot drift.
    model_kwargs = dict(p_norm=1, use_attributes=True, do_sigmoid=True)
    model = CQDTransEA(
        dataset.entity_num,
        args.rank,
        dataset.att_num,
        dataset.relation_num,
        **model_kwargs,
    )

    regularizer = {
        'N2': N2(args.reg),
        'N3': N3(args.reg),
    }[args.regularizer]

    device = f"cuda:{args.cuda_device}" if args.use_cuda else 'cpu'
    model.to(device)

    # Optimization method
    optim_method = optim.Adagrad(model.parameters(), lr=args.learning_rate)
    # Loss functions for the relation and attribute objectives
    rel_loss = CELoss(10)
    att_loss = MAELoss(0)

    # Trainer wrapping model, losses, regularizer and optimizer
    lp_optimizer = LP_optimizer(model, rel_loss, att_loss, regularizer, optim_method, args.batch_size, args.alpha)

    #### train

    ## Build the dataloaders
    entity_dataset = EntityDataset(args.data_path, device, "train")
    attribute_dataset = AttributeDataset(args.data_path, device, "train")

    batch_size_entity = args.batch_size
    # ceil(len / batch_size) entity batches; max(1, ...) prevents a
    # ZeroDivisionError on the next line when the entity dataset is empty.
    nbatches = max(1, (entity_dataset.len // batch_size_entity) + (entity_dataset.len % batch_size_entity > 0))
    # Attribute batch size chosen so the attribute loader yields at most
    # `nbatches` batches; DataLoader rejects batch_size=0, hence max(1, ...).
    batch_size_att = max(1, (attribute_dataset.len // nbatches) + (attribute_dataset.len % nbatches > 0))

    entity_train_dataloader = DataLoader(
        dataset=entity_dataset,
        batch_size=batch_size_entity,
        shuffle=True,
        collate_fn=EntityDataset.collate_fn,
    )
    att_train_dataloader = DataLoader(
        dataset=attribute_dataset,
        batch_size=batch_size_att,
        shuffle=True,
        collate_fn=AttributeDataset.collate_fn,
    )

    writer = SummaryWriter(args.log_dir)
    try:
        for epoch in range(1, args.max_epochs + 1):
            # Evaluation runs *before* training epochs 1, k, 2k, ...: the epoch-1
            # run is an untrained baseline, and scalars logged at step `epoch`
            # reflect the model after `epoch - 1` training epochs.
            if epoch % args.valid_and_test == 0 or epoch == 1:
                log_evaluation(writer, dataset, model, device, epoch)
            lp_optimizer.train_epoch(epoch, entity_train_dataloader, att_train_dataloader)
        # The loop never evaluates after the last training epoch, so the model
        # being saved would otherwise go unscored; log it under the next step
        # to keep the same "step = completed epochs + 1" convention.
        log_evaluation(writer, dataset, model, device, args.max_epochs + 1)
    finally:
        writer.close()

    print("train over,saving model...")
    timestamp = str(int(time.time()))
    model_dir = os.getcwd()
    torch.save(
        {
            "entity_num": dataset.entity_num,
            "rank": args.rank,
            "att_num": dataset.att_num,
            "relation_num": dataset.relation_num,
            **model_kwargs,
            "model_state_dict": lp_optimizer.model.state_dict(),
        },
        os.path.join(model_dir, "epoch_{}_{}".format(args.max_epochs, timestamp)),
    )
    with open(os.path.join(model_dir, 'model-metadata-{}.json'.format(timestamp)), 'w') as json_file:
        json.dump(vars(args), json_file)

    print("work done!")